import torch

# Quick check of an elementwise max followed by an arithmetic complement.

# 3x3 float32 matrix. It is kept as a module-level name, but nothing in
# this script reads it.
a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float32)

# Elementwise maximum of two float32 vectors gives tensor([0., 1., 1.]).
# `b` is float32, not uint8: it is produced directly here rather than
# pre-allocated, because any earlier value would be overwritten unused.
b = torch.maximum(
    torch.tensor([0, 0, 0], dtype=torch.float32),
    torch.tensor([0, 1, 1], dtype=torch.float32),
)

# Arithmetic complement of a float tensor; prints tensor([1., 0., 0.]).
print(1 - b)
